{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "945435b4",
   "metadata": {},
   "source": [
    "# Pytorch-Lightning Integration for DeepChem Models\n",
    "\n",
    "\n",
    "In this tutorial we will go through how to setup a deepchem model inside the [pytorch-lightning](https://www.pytorchlightning.ai/) framework. Lightning is a pytorch framework which simplifies the process of experimenting with pytorch models easier. A few key functionalities offered by pytorch lightning which deepchem users can find useful are:\n",
    "\n",
    "1. Multi-gpu training functionalities: pytorch-lightning provides easy multi-gpu, multi-node training. It also simplifies the process of launching multi-gpu, multi-node jobs across different cluster infrastructure, e.g. AWS, slurm based clusters.\n",
    "\n",
    "1. Reducing boilerplate pytorch code: lightning takes care of details like, `optimizer.zero_grad(), model.train(), model.eval()`. Lightning also provides experiment logging functionality, for e.g. irrespective of training on CPU, GPU, multi-nodes the user can use the method `self.log` inside the trainer and it will appropriately log the metrics.\n",
    "1. Features that can speed up training: half-precision training, gradient checkpointing, code profiling.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VVLqq0vMlPkSEXeqcFnHY_zEuvDOQu50?usp=sharing)"
   ]
  },
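  {
   "cell_type": "markdown",
   "id": "a3f1c7e2",
   "metadata": {},
   "source": [
    "For context, the boilerplate mentioned above looks roughly like the following in plain PyTorch. This is a minimal sketch, not DeepChem code: the linear model, the random data, and the hyperparameters are assumptions chosen only to make the loop runnable.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Toy model and data (assumed for illustration only).\n",
    "model = nn.Linear(4, 1)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "loss_fn = nn.MSELoss()\n",
    "inputs, targets = torch.randn(8, 4), torch.randn(8, 1)\n",
    "\n",
    "model.train()  # switch to training mode by hand\n",
    "losses = []\n",
    "for epoch in range(3):\n",
    "    optimizer.zero_grad()  # clear gradients from the previous step\n",
    "    loss = loss_fn(model(inputs), targets)\n",
    "    loss.backward()  # compute gradients\n",
    "    optimizer.step()  # update the parameters\n",
    "    losses.append(loss.item())\n",
    "\n",
    "model.eval()  # switch back to evaluation mode by hand\n",
    "```\n",
    "\n",
    "With Lightning, the body of the loop moves into `training_step`, and the trainer issues the `zero_grad`, `backward`, and `step` calls as well as the mode switches."
   ]
  },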
  {
   "cell_type": "markdown",
   "id": "bbb99230",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "- This notebook assumes that you have already installed deepchem, if you have not follow the instructions at the deepchem installation page: https://deepchem.readthedocs.io/en/latest/get_started/installation.html.\n",
    "- Install pytorch lightning following the instructions on lightning's home page: https://www.pytorchlightning.ai/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ba6b1b0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: deepchem in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (2.6.1.dev20220119163852)\n",
      "Requirement already satisfied: numpy>=1.21 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from deepchem) (1.22.0)\n",
      "Requirement already satisfied: scikit-learn in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from deepchem) (1.0.2)\n",
      "Requirement already satisfied: pandas in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from deepchem) (1.4.0)\n",
      "Collecting rdkit-pypi\n",
      "  Downloading rdkit_pypi-2021.9.5.1-cp38-cp38-macosx_11_0_arm64.whl (15.9 MB)\n",
      "\u001b[K     |████████████████████████████████| 15.9 MB 6.8 MB/s eta 0:00:01\n",
      "\u001b[?25hRequirement already satisfied: joblib in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from deepchem) (1.1.0)\n",
      "Requirement already satisfied: scipy in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from deepchem) (1.7.3)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from scikit-learn->deepchem) (3.0.0)\n",
      "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pandas->deepchem) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pandas->deepchem) (2021.3)\n",
      "Requirement already satisfied: Pillow in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from rdkit-pypi->deepchem) (8.4.0)\n",
      "Requirement already satisfied: six>=1.5 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from python-dateutil>=2.8.1->pandas->deepchem) (1.16.0)\n",
      "Installing collected packages: rdkit-pypi\n",
      "Successfully installed rdkit-pypi-2021.9.5.1\n",
      "Requirement already satisfied: pytorch_lightning in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (1.5.8)\n",
      "Requirement already satisfied: typing-extensions in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (4.0.1)\n",
      "Requirement already satisfied: numpy>=1.17.2 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (1.22.0)\n",
      "Requirement already satisfied: torch>=1.7.* in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (1.10.2)\n",
      "Requirement already satisfied: tensorboard>=2.2.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (2.7.0)\n",
      "Requirement already satisfied: tqdm>=4.41.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (4.62.3)\n",
      "Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (2022.1.0)\n",
      "Requirement already satisfied: packaging>=17.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (21.3)\n",
      "Requirement already satisfied: PyYAML>=5.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (6.0)\n",
      "Requirement already satisfied: pyDeprecate==0.3.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (0.3.1)\n",
      "Processing /Users/princychahal/Library/Caches/pip/wheels/8e/70/28/3d6ccd6e315f65f245da085482a2e1c7d14b90b30f239e2cf4/future-0.18.2-py3-none-any.whl\n",
      "Requirement already satisfied: torchmetrics>=0.4.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pytorch_lightning) (0.7.0)\n",
      "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (0.6.0)\n",
      "Requirement already satisfied: absl-py>=0.4 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (1.0.0)\n",
      "Requirement already satisfied: grpcio>=1.24.3 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (1.43.0)\n",
      "Requirement already satisfied: requests<3,>=2.21.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (2.27.1)\n",
      "Requirement already satisfied: google-auth<3,>=1.6.3 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (2.3.3)\n",
      "Requirement already satisfied: wheel>=0.26 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (0.37.1)\n",
      "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (0.4.6)\n",
      "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (1.8.1)\n",
      "Requirement already satisfied: setuptools>=41.0.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (60.5.0)\n",
      "Requirement already satisfied: werkzeug>=0.11.15 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (2.0.2)\n",
      "Requirement already satisfied: markdown>=2.6.8 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (3.3.6)\n",
      "Requirement already satisfied: protobuf>=3.6.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch_lightning) (3.18.1)\n",
      "Requirement already satisfied: aiohttp; extra == \"http\" in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (3.8.1)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from packaging>=17.0->pytorch_lightning) (3.0.7)\n",
      "Requirement already satisfied: six in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from absl-py>=0.4->tensorboard>=2.2.0->pytorch_lightning) (1.16.0)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch_lightning) (1.26.8)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch_lightning) (2021.10.8)\n",
      "Requirement already satisfied: charset-normalizer~=2.0.0; python_version >= \"3\" in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch_lightning) (2.0.10)\n",
      "Requirement already satisfied: idna<4,>=2.5; python_version >= \"3\" in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard>=2.2.0->pytorch_lightning) (3.3)\n",
      "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (4.2.4)\n",
      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (0.2.7)\n",
      "Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3.6\" in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (4.8)\n",
      "Requirement already satisfied: requests-oauthlib>=0.7.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch_lightning) (1.3.0)\n",
      "Requirement already satisfied: importlib-metadata>=4.4; python_version < \"3.10\" in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch_lightning) (4.10.1)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from aiohttp; extra == \"http\"->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (1.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from aiohttp; extra == \"http\"->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (1.2.0)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from aiohttp; extra == \"http\"->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (4.0.2)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from aiohttp; extra == \"http\"->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (6.0.2)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from aiohttp; extra == \"http\"->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (1.7.2)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from aiohttp; extra == \"http\"->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch_lightning) (21.4.0)\n",
      "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch_lightning) (0.4.8)\n",
      "Requirement already satisfied: oauthlib>=3.0.0 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch_lightning) (3.1.1)\n",
      "Requirement already satisfied: zipp>=0.5 in /Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages (from importlib-metadata>=4.4; python_version < \"3.10\"->markdown>=2.6.8->tensorboard>=2.2.0->pytorch_lightning) (3.7.0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Installing collected packages: future\n",
      "Successfully installed future-0.18.2\n"
     ]
    }
   ],
   "source": [
    "!pip install --pre deepchem\n",
    "!pip install pytorch_lightning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98ee91c0",
   "metadata": {},
   "source": [
    "Import the relevant packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6890f523",
   "metadata": {},
   "outputs": [],
   "source": [
    "import deepchem as dc\n",
    "from deepchem.models import GCNModel\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch import nn\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.core.lightning import LightningModule\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe62d7c3",
   "metadata": {},
   "source": [
    "## Deepchem Example\n",
    "\n",
    "Below we show an example of a Graph Convolution Network (GCN). Note that this is a simple example which uses a GCNModel to predict the label from an input sequence. We do not showcase the complete functionality of deepchem in this example as we want to restructure the deepchem code and adapt it so that it can be easily plugged into pytorch-lightning. This example was inspired from the `GCNModel` documentation present [here](https://github.com/deepchem/deepchem/blob/a68f8c072b80a1bce5671250aef60f9cc8519bec/deepchem/models/torch_models/gcn.py#L200)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2859f97b",
   "metadata": {},
   "source": [
    "**Prepare the dataset**: for training our deepchem models we need a dataset that we can use to train the model. Below we prepare a sample dataset for the purposes of this tutorial. Below we also directly use the featurized to encode examples for the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3789e1aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "smiles = [\"C1CCC1\", \"CCC\"]\n",
    "labels = [0., 1.]\n",
    "featurizer = dc.feat.MolGraphConvFeaturizer()\n",
    "X = featurizer.featurize(smiles)\n",
    "dataset = dc.data.NumpyDataset(X=X, y=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bc105a0",
   "metadata": {},
   "source": [
    "**Setup the model**: now we initialize the Graph Convolutional Network model that we will use in our training. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6d8fb2f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[16:00:37] /Users/princychahal/Documents/github/dgl/src/runtime/tensordispatch.cc:43: TensorDispatcher: dlopen failed: Using backend: pytorch\n",
      "dlopen(/Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages/dgl-0.8-py3.8-macosx-11.0-arm64.egg/dgl/tensoradapter/pytorch/libtensoradapter_pytorch_1.10.2.dylib, 1): image not found\n"
     ]
    }
   ],
   "source": [
    "model = GCNModel(\n",
    "    mode='classification',\n",
    "    n_tasks=1,\n",
    "    batch_size=2,\n",
    "    learning_rate=0.001\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5f76e7b",
   "metadata": {},
   "source": [
    "**Train the model**: fit the model on our training dataset, also specify the number of epochs to run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4bdb2d8b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.18830760717391967\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages/torch/autocast_mode.py:141: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
      "  warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n"
     ]
    }
   ],
   "source": [
    "loss = model.fit(dataset, nb_epoch=5)\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f74d813",
   "metadata": {},
   "source": [
    "## Pytorch-Lightning + Deepchem example\n",
    "\n",
    "Now we will look at an example of the GCN model adapt for Pytorch-Lightning. For using Pytorch-Lightning there are two important components:\n",
    "1. `LightningDataModule`: This module defines who the data is prepared and fed into the model so that the model can use it for training. The module defines the train dataloader function which are directly used by the trainer to generate data for the `LightningModule`. To learn more about the `LightningDataModule` refer to the [datamodules documentation](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html).\n",
    "2. `LightningModule`: This module defines the training, validation steps for our model. We can use this module to initialize our model based on the hyperparameters. There are a number of boilerplate functions which we use directly to track our experiments, for example we can save all the hyperparameters that we used for training using the `self.save_hyperparameters()` method. For more details on how to use this module refer to the [lightningmodules documentation](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48b1523b",
   "metadata": {},
   "source": [
    "**Setup the torch dataset**: Note that here we need to create a custome `SmilesDataset` so that we can easily interface with the deepchem featurizers. For this interface we need to define a collate method so that we can create batches for the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0c16f761",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare LightningDataModule\n",
    "class SmilesDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, smiles, labels):\n",
    "        assert len(smiles) == len(labels)\n",
    "        featurizer = dc.feat.MolGraphConvFeaturizer()\n",
    "        X = featurizer.featurize(smiles)\n",
    "        self._samples = dc.data.NumpyDataset(X=X, y=labels)\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self._samples)\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        return (\n",
    "            self._samples.X[index],\n",
    "            self._samples.y[index],\n",
    "            self._samples.w[index],\n",
    "        )\n",
    "    \n",
    "    \n",
    "class SmilesDatasetBatch:\n",
    "    def __init__(self, batch):\n",
    "        X = [np.array([b[0] for b in batch])]\n",
    "        y = [np.array([b[1] for b in batch])]\n",
    "        w = [np.array([b[2] for b in batch])]\n",
    "        self.batch_list = [X, y, w]\n",
    "        \n",
    "        \n",
    "def collate_smiles_dataset_wrapper(batch):\n",
    "    return SmilesDatasetBatch(batch)"
   ]
  },
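  {
   "cell_type": "markdown",
   "id": "d94b6e15",
   "metadata": {},
   "source": [
    "To make the batch layout concrete, the sketch below mimics what the collate wrapper above produces, using plain NumPy. The string features and the `collate` helper are hypothetical stand-ins: in the notebook the features are the graph objects returned by the featurizer. Each field is wrapped in a one-element list, matching the `[X, y, w]` layout that the training step later hands to `_prepare_batch`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-ins for featurized samples: (features, label, weight).\n",
    "samples = [(\"graph_a\", 0.0, 1.0), (\"graph_b\", 1.0, 1.0)]\n",
    "\n",
    "\n",
    "def collate(batch):\n",
    "    # Gather each field across the batch, then wrap it in a one-element list.\n",
    "    X = [np.array([sample[0] for sample in batch], dtype=object)]\n",
    "    y = [np.array([sample[1] for sample in batch])]\n",
    "    w = [np.array([sample[2] for sample in batch])]\n",
    "    return [X, y, w]\n",
    "\n",
    "\n",
    "batch_list = collate(samples)\n",
    "X, y, w = batch_list\n",
    "print(len(batch_list), X[0].shape, y[0].shape, w[0].shape)\n",
    "```\n",
    "\n",
    "Passing a function like this as `collate_fn` to a `DataLoader` replaces the default tensor stacking, which is useful when samples are arbitrary Python objects rather than tensors."
   ]
  },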
  {
   "cell_type": "markdown",
   "id": "bc5c68b8",
   "metadata": {},
   "source": [
    "**Create the GCN specific lightning module**: in this part we use an object of the `SmilesDataset` created above to create the `SmilesDatasetModule`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "44df652a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SmilesDatasetModule(pl.LightningDataModule):\n",
    "    def __init__(self, train_smiles, train_labels, batch_size):\n",
    "        super().__init__()\n",
    "        self._train_smiles = train_smiles\n",
    "        self._train_labels = train_labels\n",
    "        self._batch_size = batch_size\n",
    "        \n",
    "    def setup(self, stage):\n",
    "        self.train_dataset = SmilesDataset(\n",
    "            self._train_smiles,\n",
    "            self._train_labels,\n",
    "        )\n",
    "        \n",
    "    def train_dataloader(self):\n",
    "        return torch.utils.data.DataLoader(\n",
    "            self.train_dataset,\n",
    "            batch_size=self._batch_size,\n",
    "            collate_fn=collate_smiles_dataset_wrapper,\n",
    "            shuffle=True,  \n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4115ec6e",
   "metadata": {},
   "source": [
    "**Create the lightning module**: in this part we create the GCN specific lightning module. This class specifies the logic flow for the training step. We also create the required models, optimizers and losses for the training flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "61fdd620",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare the LightningModule\n",
    "class GCNModule(pl.LightningModule):\n",
    "    def __init__(self, mode, n_tasks, learning_rate):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters(\n",
    "            \"mode\",\n",
    "            \"n_tasks\",\n",
    "            \"learning_rate\",\n",
    "        )\n",
    "        self.gcn_model = GCNModel(\n",
    "            mode=self.hparams.mode,\n",
    "            n_tasks=self.hparams.n_tasks,\n",
    "            learning_rate=self.hparams.learning_rate,\n",
    "        )\n",
    "        self.pt_model = self.gcn_model.model\n",
    "        self.loss = self.gcn_model._loss_fn\n",
    "        \n",
    "    def configure_optimizers(self):\n",
    "        return self.gcn_model.optimizer._create_pytorch_optimizer(\n",
    "            self.pt_model.parameters(),\n",
    "        )\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        batch = batch.batch_list\n",
    "        inputs, labels, weights = self.gcn_model._prepare_batch(batch)\n",
    "        outputs = self.pt_model(inputs)\n",
    "        \n",
    "        if isinstance(outputs, torch.Tensor):\n",
    "            outputs = [outputs]\n",
    "    \n",
    "        if self.gcn_model._loss_outputs is not None:\n",
    "            outputs = [outputs[i] for i in self.gcn_model._loss_outputs]\n",
    "    \n",
    "        loss_outputs = self.loss(outputs, labels, weights)\n",
    "        \n",
    "        self.log(\n",
    "            \"train_loss\",\n",
    "            loss_outputs,\n",
    "            on_epoch=True,\n",
    "            sync_dist=True,\n",
    "            reduce_fx=\"mean\",\n",
    "            prog_bar=True,\n",
    "        )\n",
    "        \n",
    "        return loss_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3b167c5",
   "metadata": {},
   "source": [
    "**Create the relevant objects**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bfff80bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create module objects\n",
    "smiles_datasetmodule = SmilesDatasetModule(\n",
    "    train_smiles=[\"C1CCC1\", \"CCC\", \"C1CCC1\", \"CCC\", \"C1CCC1\", \"CCC\", \"C1CCC1\", \"CCC\", \"C1CCC1\", \"CCC\"],\n",
    "    train_labels=[0., 1., 0., 1., 0., 1., 0., 1., 0., 1.],\n",
    "    batch_size=2,\n",
    ")\n",
    "\n",
    "gcnmodule = GCNModule(\n",
    "    mode=\"classification\",\n",
    "    n_tasks=1,\n",
    "    learning_rate=1e-3,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed902521",
   "metadata": {},
   "source": [
    "## Lightning Trainer\n",
    "\n",
    "Trainer is the wrapper which builds on top of the `LightningDataModule` and `LightningModule`. When constructing the lightning trainer you can also specify the number of epochs, max-steps to run, number of GPUs, number of nodes to be used for trainer. Lightning trainer acts as a wrapper over your distributed training setup and this way you are able to build your models in a way you would build them in a simple way for your local runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9e002e3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    max_epochs=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "870271cd",
   "metadata": {},
   "source": [
    "**Call the fit function to run model training**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "00d35e97",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  | Name     | Type | Params\n",
      "----------------------------------\n",
      "0 | pt_model | GCN  | 29.4 K\n",
      "----------------------------------\n",
      "29.4 K    Trainable params\n",
      "0         Non-trainable params\n",
      "29.4 K    Total params\n",
      "0.118     Total estimated model params size (MB)\n",
      "/Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
      "  rank_zero_warn(\n",
      "/Users/princychahal/mambaforge/envs/keras_try_5/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:428: UserWarning: The number of training samples (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b2dd36ca21014476917dc7cf922ce6b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# train\n",
    "trainer.fit(\n",
    "    model=gcnmodule,\n",
    "    datamodule=smiles_datasetmodule,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
